import torch
import torch.nn as nn
import torchvision.models
import torchvision.models.segmentation as sg
import network.resnet as resnet

class DeepLabV3(nn.Module):
    """Thin wrapper around torchvision's DeepLabV3 with a ResNet-50 backbone.

    ``forward`` returns only the ``"out"`` entry of the torchvision model's
    output dict, so the wrapper behaves like a plain tensor-to-tensor
    segmentation module.
    """

    def __init__(self, num_classes: int = 2, **kwargs) -> None:
        """Build the underlying ``deeplabv3_resnet50`` model.

        Args:
            num_classes: Number of output channels of the segmentation head.
                Defaults to 2, the value this wrapper previously hard-coded.
            **kwargs: Forwarded unchanged to
                ``torchvision.models.segmentation.deeplabv3_resnet50``
                (e.g. weight or ``aux_loss`` options; the accepted names
                depend on the installed torchvision version).
        """
        super().__init__()
        # NOTE(review): depending on the torchvision version, the default
        # arguments may load pretrained backbone weights; pass the relevant
        # weights option through **kwargs to control this.
        self.backbone = sg.deeplabv3_resnet50(num_classes=num_classes, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the segmentation model and return its main output.

        Args:
            x: Input batch; presumably an (N, 3, H, W) image tensor — TODO
                confirm against callers.

        Returns:
            The ``"out"`` tensor from the torchvision model's output dict.
            Any other entries (e.g. an auxiliary-head output) are discarded.
        """
        return self.backbone(x)["out"]